# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functions to load bsuite environments."""

from typing import Any, Mapping, Tuple

from bsuite import sweep
from bsuite.environments import base
from bsuite.experiments.bandit import bandit
from bsuite.experiments.bandit_noise import bandit_noise
from bsuite.experiments.bandit_scale import bandit_scale
from bsuite.experiments.cartpole import cartpole
from bsuite.experiments.cartpole_noise import cartpole_noise
from bsuite.experiments.cartpole_scale import cartpole_scale
from bsuite.experiments.cartpole_swingup import cartpole_swingup
from bsuite.experiments.catch import catch
from bsuite.experiments.catch_noise import catch_noise
from bsuite.experiments.catch_scale import catch_scale
from bsuite.experiments.deep_sea import deep_sea
from bsuite.experiments.deep_sea_stochastic import deep_sea_stochastic
from bsuite.experiments.discounting_chain import discounting_chain
from bsuite.experiments.memory_len import memory_len
from bsuite.experiments.memory_size import memory_size
from bsuite.experiments.mnist import mnist
from bsuite.experiments.mnist_noise import mnist_noise
from bsuite.experiments.mnist_scale import mnist_scale
from bsuite.experiments.mountain_car import mountain_car
from bsuite.experiments.mountain_car_noise import mountain_car_noise
from bsuite.experiments.mountain_car_scale import mountain_car_scale
from bsuite.experiments.umbrella_distract import umbrella_distract
from bsuite.experiments.umbrella_length import umbrella_length

from bsuite.logging import csv_logging
from bsuite.logging import terminal_logging

import dm_env
import termcolor

# Internal imports.

# Registry of every bsuite experiment, keyed by experiment name. Each value is
# the callable that builds that experiment's environment: usually the
# experiment module's `load` function, except for `cartpole_swingup`, which
# registers its environment class directly. The callable is invoked with the
# keyword arguments defined for that experiment in its sweep.py file.
EXPERIMENT_NAME_TO_ENVIRONMENT = {
    'bandit': bandit.load,
    'bandit_noise': bandit_noise.load,
    'bandit_scale': bandit_scale.load,
    'cartpole': cartpole.load,
    'cartpole_noise': cartpole_noise.load,
    'cartpole_scale': cartpole_scale.load,
    'cartpole_swingup': cartpole_swingup.CartpoleSwingup,
    'catch': catch.load,
    'catch_noise': catch_noise.load,
    'catch_scale': catch_scale.load,
    'deep_sea': deep_sea.load,
    'deep_sea_stochastic': deep_sea_stochastic.load,
    'discounting_chain': discounting_chain.load,
    'memory_len': memory_len.load,
    'memory_size': memory_size.load,
    'mnist': mnist.load,
    'mnist_noise': mnist_noise.load,
    'mnist_scale': mnist_scale.load,
    'mountain_car': mountain_car.load,
    'mountain_car_noise': mountain_car_noise.load,
    'mountain_car_scale': mountain_car_scale.load,
    'umbrella_distract': umbrella_distract.load,
    'umbrella_length': umbrella_length.load,
}


def unpack_bsuite_id(bsuite_id: str) -> Tuple[str, int]:
  """Returns the experiment name and setting index given a bsuite_id.

  Args:
    bsuite_id: An id made of an experiment name and an integer setting index
      joined by `sweep.SEPARATOR`, for example "catch/0" or "deep_sea/3".

  Returns:
    A tuple `(experiment_name, setting_index)`.

  Raises:
    ValueError: If `bsuite_id` does not split into exactly two parts on
      `sweep.SEPARATOR`, or if the setting index is not an integer.
  """
  parts = bsuite_id.split(sweep.SEPARATOR)
  # Validate explicitly instead of using `assert`: assertions are stripped
  # when Python runs with -O, which would let malformed ids slip through or
  # fail later with an unrelated IndexError.
  if len(parts) != 2:
    raise ValueError(
        f'Invalid bsuite_id {bsuite_id!r}: expected an experiment name and a '
        f'setting index separated by exactly one {sweep.SEPARATOR!r}.')
  experiment_name, setting_index_str = parts
  # int() raises ValueError if the index part is not an integer.
  return experiment_name, int(setting_index_str)


def load(
    experiment_name: str,
    kwargs: Mapping[str, Any],
) -> base.Environment:
  """Returns a bsuite environment given an experiment name and settings.

  Args:
    experiment_name: A key of `EXPERIMENT_NAME_TO_ENVIRONMENT`, e.g. "catch".
    kwargs: Keyword arguments forwarded unchanged to the experiment's
      registered constructor or load function.

  Returns:
    The environment produced by the registered constructor or load function.

  Raises:
    KeyError: If `experiment_name` is not a registered experiment.
  """
  # Look the constructor up separately from calling it, so that a KeyError
  # raised inside the constructor is not confused with an unknown name.
  try:
    constructor = EXPERIMENT_NAME_TO_ENVIRONMENT[experiment_name]
  except KeyError:
    raise KeyError(
        f'Unknown experiment_name {experiment_name!r}. Must be one of: '
        f'{sorted(EXPERIMENT_NAME_TO_ENVIRONMENT)}.') from None
  return constructor(**kwargs)


def load_from_id(bsuite_id: str) -> base.Environment:
  """Builds the bsuite environment identified by `bsuite_id`.

  The keyword settings are read from `sweep.SETTINGS`, the experiment name is
  parsed from the id, and a bold confirmation line is printed once the
  environment has been constructed.

  Args:
    bsuite_id: A bsuite id such as "catch/0".

  Returns:
    The constructed environment.
  """
  settings = sweep.SETTINGS[bsuite_id]
  name = unpack_bsuite_id(bsuite_id)[0]
  environment = load(experiment_name=name, kwargs=settings)
  message = f'Loaded bsuite_id: {bsuite_id}.'
  termcolor.cprint(message, color='white', attrs=['bold'])
  return environment


def load_and_record(bsuite_id: str,
                    save_path: str,
                    logging_mode: str = 'csv',
                    overwrite: bool = False) -> dm_env.Environment:
  """Loads a bsuite environment and attaches the requested results logger.

  Args:
    bsuite_id: A bsuite id such as "catch/0".
    save_path: Directory for CSV results; only used when `logging_mode` is
      'csv'.
    logging_mode: Either 'csv' or 'terminal'.
    overwrite: Passed through to the CSV logger; ignored for 'terminal'.

  Returns:
    The environment wrapped with the chosen logger.

  Raises:
    ValueError: If `logging_mode` is neither 'csv' nor 'terminal'.
  """
  if logging_mode == 'csv':
    return load_and_record_to_csv(
        bsuite_id, results_dir=save_path, overwrite=overwrite)
  if logging_mode == 'terminal':
    return load_and_record_to_terminal(bsuite_id)
  raise ValueError(
      f'Unrecognised logging_mode "{logging_mode}". '
      'Must be "csv" or "terminal".')


def load_and_record_to_csv(bsuite_id: str,
                           results_dir: str,
                           overwrite: bool = False) -> dm_env.Environment:
  """Loads a bsuite environment whose results are written to CSV.

  The saved results can be inspected with the provided notebook, or loaded
  manually with:

  ```python
  from bsuite.logging import csv_load

  results_df, sweep_vars = csv_load.load_bsuite(results_dir)
  ```

  Args:
    bsuite_id: Identifies the environment to build, e.g. "catch/0" or
      "deep_sea/3".
    results_dir: Directory in which the CSV files are stored. The logger
      writes one CSV file per bsuite_id.
    overwrite: If True, existing CSV files are overwritten.

  Returns:
    The environment for `bsuite_id`, wrapped with the CSV logger.
  """
  environment = load_from_id(bsuite_id)
  notice = ('Logging results to CSV file for each bsuite_id in '
            f'{results_dir}.')
  termcolor.cprint(notice, color='yellow', attrs=['bold'])
  return csv_logging.wrap_environment(
      env=environment,
      bsuite_id=bsuite_id,
      results_dir=results_dir,
      overwrite=overwrite)


def load_and_record_to_terminal(bsuite_id: str) -> dm_env.Environment:
  """Loads a bsuite environment whose results are logged to the terminal.

  Args:
    bsuite_id: A bsuite id such as "catch/0".

  Returns:
    The environment wrapped with the terminal logger.
  """
  environment = load_from_id(bsuite_id)
  notice = 'Logging results to terminal.'
  termcolor.cprint(notice, color='yellow', attrs=['bold'])
  return terminal_logging.wrap_environment(environment)
